11840. Sum of Squares with Segment Tree


Segment trees are extremely useful. In particular, "Lazy Propagation" allows one to compute sums over a range in O(lg n) and to update ranges in O(lg n) as well. In this problem you will compute something much harder:

The sum of squares over a range with range updates of 2 types:

1) increment in a range

2) set all numbers the same in a range.


Input. The first line of the input contains the number of test cases t (t ≤ 25). Each test case starts with a line containing two positive integers n (n ≤ 100,000) and q (q ≤ 100,000). The next line contains n integers, each at most 1000. Each of the next q lines starts with a number that indicates the type of operation:

2 st nd – return the sum of the squares of the numbers with indices in [st, nd], i.e., from st to nd inclusive (1 ≤ st ≤ nd ≤ n).

1 st nd x – add x to all numbers with indices in [st, nd] (1 ≤ st ≤ nd ≤ n, and −1,000 ≤ x ≤ 1,000).

0 st nd x – set all numbers with indices in [st, nd] to x (1 ≤ st ≤ nd ≤ n, and −1,000 ≤ x ≤ 1,000).


Output. For each test case, output "Case <caseno>:" on the first line, and then, starting from the second line, the sum of squares for each operation of type 2, one per line. Intermediate overflow will not occur with proper use of 64-bit signed integers.


Sample Input


2
4 5
1 2 3 4
2 1 4
0 3 4 1
2 1 4
1 3 4 1
2 1 4
1 1
1
2 1 1


Sample Output

Case 1:
30
7
13
Case 2:
1

Структуры данных – дерево отрезков


Анализ алгоритма

В задаче следует реализовать две групповые операции на отрезке: прибавление и присваивание. В каждой вершине дерева отрезков объявим две переменные add и set для хранения информации об отложенных операциях, и соответственно при проталкивании (операции push) обрабатываем их отдельно. Ещё следует реализовать поддержку суммы квадратов на отрезке.

Рассмотрим отрезок $[i; j]$ с числами $a_i, \ldots, a_j$. Пусть ко всем числам отрезка добавлено число $v$. Сумма на отрезке увеличится на $(j - i + 1) \cdot v$. Рассмотрим, на сколько увеличится сумма квадратов на отрезке. После увеличения чисел на $v$ квадраты на отрезке станут равными $(a_i + v)^2, (a_{i+1} + v)^2, \ldots, (a_j + v)^2$. Их сумма равна
$$\sum_{k=i}^{j} (a_k + v)^2 = \left(a_i^2 + a_{i+1}^2 + \ldots + a_j^2\right) + 2v\,(a_i + \ldots + a_j) + (j - i + 1)\,v^2.$$
То есть при добавлении $v$ ко всем числам отрезка к текущей сумме квадратов следует добавить $2v \cdot (\text{сумма чисел на отрезке}) + (j - i + 1)\,v^2$. Поэтому вместе с суммой квадратов на отрезке следует также поддерживать и сумму чисел на отрезке.


Реализация алгоритма


#include <cstdio>

#include <algorithm>

#define MAX 100010

#define NORMAL 0

#define ADD 1

#define SET 2

using namespace std;


// One segment-tree node. The tree lives in the array SegTree; the
// children of vertex v are 2v and 2v+1 (see build).
//   sum   - sum of the elements of the node's segment;
//   sumSq - sum of the squares of those elements;
//   type  - pending (lazy) operation not yet pushed to the children:
//           NORMAL (none), ADD or SET;
//   add   - argument of that operation: the increment for ADD,
//           the assigned value for SET.
// sum and sumSq already include the node's own pending operation.
struct node
{
  long long sum, sumSq, type, add;
} SegTree[4*MAX];

// Input values; main reads them into indices 1..n.
long long mas[MAX];


void build(long long *a, int Vertex, int Left, int Right)


  SegTree[Vertex].type = NORMAL;

  SegTree[Vertex].add = 0;

  if (Left == Right)


    SegTree[Vertex].sum = a[Left];

    SegTree[Vertex].sumSq = 1LL * a[Left] * a[Left];




    int Middle = (Left + Right) / 2;


    build (a, 2*Vertex, Left, Middle);

    build (a, 2*Vertex+1, Middle+1, Right);


    SegTree[Vertex].sum = SegTree[2*Vertex].sum + SegTree[2*Vertex+1].sum;

    SegTree[Vertex].sumSq =

      SegTree[2*Vertex].sumSq + SegTree[2*Vertex+1].sumSq;




void Push(int Vertex, int LeftPos, int Middle, int RightPos)


  if (SegTree[Vertex].type == SET)


    SegTree[2*Vertex].type = SegTree[2*Vertex+1].type = SegTree[Vertex].type;

    SegTree[2*Vertex].add = SegTree[2*Vertex+1].add = SegTree[Vertex].add;


    SegTree[2*Vertex].sum = (Middle - LeftPos + 1) * SegTree[Vertex].add;

    SegTree[2*Vertex].sumSq = (Middle - LeftPos + 1) *

                               SegTree[Vertex].add * SegTree[Vertex].add;


    SegTree[2*Vertex+1].sum = (RightPos - Middle) * SegTree[Vertex].add;

    SegTree[2*Vertex+1].sumSq = (RightPos - Middle) *

                                 SegTree[Vertex].add * SegTree[Vertex].add;


    SegTree[Vertex].add = 0;

    SegTree[Vertex].type = NORMAL;



  if (SegTree[Vertex].type == ADD)


    SegTree[2*Vertex].add += SegTree[Vertex].add;

    SegTree[2*Vertex].sumSq += (Middle - LeftPos + 1) *

            SegTree[Vertex].add * SegTree[Vertex].add +

            2LL * SegTree[Vertex].add * SegTree[2*Vertex].sum;

    SegTree[2*Vertex].sum += (Middle - LeftPos + 1) * SegTree[Vertex].add;

    if (SegTree[2*Vertex].type == NORMAL) SegTree[2*Vertex].type = ADD;

    if (SegTree[2*Vertex+1].type == NORMAL) SegTree[2*Vertex+1].type = ADD;


    SegTree[2*Vertex+1].add += SegTree[Vertex].add;

    SegTree[2*Vertex+1].sumSq += (RightPos - Middle) * SegTree[Vertex].add  *

                         SegTree[Vertex].add +

                         2LL * SegTree[Vertex].add * SegTree[2*Vertex+1].sum;

    SegTree[2*Vertex+1].sum += (RightPos - Middle) * SegTree[Vertex].add;

    SegTree[Vertex].add = 0;

    SegTree[Vertex].type = NORMAL;




void SetValue(int Vertex, int LeftPos, int RightPos, int Left,

              int Right, int Value)


  if (Left > Right) return;

  if ((LeftPos == Left) && (RightPos == Right))


    SegTree[Vertex].add = Value;

    SegTree[Vertex].type = SET;

    SegTree[Vertex].sum = (long long)(Right - Left + 1) * Value;

    SegTree[Vertex].sumSq = (long long)(Right - Left + 1) * Value * Value;




  int Middle = (LeftPos + RightPos) / 2;



  SetValue(2*Vertex, LeftPos, Middle, Left, min(Middle,Right), Value);

  SetValue(2*Vertex+1, Middle+1, RightPos, max(Left,Middle+1), Right, Value);


  SegTree[Vertex].sum = SegTree[2*Vertex].sum + SegTree[2*Vertex+1].sum;

  SegTree[Vertex].sumSq =

    SegTree[2*Vertex].sumSq + SegTree[2*Vertex+1].sumSq;



void AddValue(int Vertex, int LeftPos, int RightPos,

              int Left, int Right, int Value)


  if (Left > Right) return;

  if ((LeftPos == Left) && (RightPos == Right))


    SegTree[Vertex].add += Value;

    if (SegTree[Vertex].type == NORMAL) SegTree[Vertex].type = ADD;


    SegTree[Vertex].sumSq += (long long)(Right - Left + 1) * Value * Value +

                             2LL * Value * SegTree[Vertex].sum;

    SegTree[Vertex].sum += (long long)(Right - Left + 1) * Value;




  int Middle = (LeftPos + RightPos) / 2;



  AddValue(2*Vertex, LeftPos, Middle, Left, min(Middle,Right), Value);

  AddValue(2*Vertex+1, Middle+1, RightPos, max(Left,Middle+1), Right, Value);


  SegTree[Vertex].sum = SegTree[2*Vertex].sum + SegTree[2*Vertex+1].sum;

  SegTree[Vertex].sumSq =

    SegTree[2*Vertex].sumSq + SegTree[2*Vertex+1].sumSq;



long long SumSq(int Vertex, int LeftPos, int RightPos, int Left, int Right)


  if (Left > Right) return 0;

  if ((LeftPos == Left) && (RightPos == Right)) return SegTree[Vertex].sumSq;


  int Middle = (LeftPos + RightPos) / 2;



  return SumSq(2*Vertex, LeftPos, Middle, Left, min(Middle,Right)) +

         SumSq(2*Vertex+1, Middle+1, RightPos, max(Left,Middle+1), Right);



int i, n, q, cs, tests, type, l, r, x;


int main(void)



  for(cs = 1; cs <= tests; cs++)


    scanf("%d %d",&n,&q);

    for(i = 1; i <= n; i++)




    printf("Case %d:\n",cs);





      if (type == 0)


        scanf("%d %d %d",&l,&r,&x);


      } else

      if (type == 1)    


        scanf("%d %d %d",&l,&r,&x);


      } else


        scanf("%d %d",&l,&r);





  return 0;



Реализация алгоритма – второй вариант


#include <cstdio>

#include <algorithm>

#define MAX 100010

#define INF 2100000000

using namespace std;


// One segment-tree node. The tree lives in the array SegTree; the
// children of vertex v are 2v and 2v+1 (see build).
//   sum   - sum of the elements of the node's segment;
//   sumSq - sum of the squares of those elements;
//   add   - pending increment not yet pushed to the children (0 = none);
//   set   - pending assignment not yet pushed to the children
//           (INF = none). When both are pending, set applies first
//           and add is applied on top of it.
// sum and sumSq already include the node's own pending operations.
struct node
{
  long long sum, sumSq, add, set;
} SegTree[4*MAX];

// Input values; main reads them into indices 1..n.
long long mas[MAX];


void build(long long *a, int Vertex, int Left, int Right)


  if (Left == Right)


    SegTree[Vertex].sum = a[Left];

    SegTree[Vertex].sumSq = 1LL * a[Left] * a[Left];

    SegTree[Vertex].add = 0;

    SegTree[Vertex].set = INF;




    int Middle = (Left + Right) / 2;


    build (a, 2*Vertex, Left, Middle);

    build (a, 2*Vertex+1, Middle+1, Right);


    SegTree[Vertex].sum = SegTree[2*Vertex].sum + SegTree[2*Vertex+1].sum;

    SegTree[Vertex].sumSq =

      SegTree[2*Vertex].sumSq + SegTree[2*Vertex+1].sumSq;

    SegTree[Vertex].add = 0;

    SegTree[Vertex].set = INF;




void Push(int Vertex, int LeftPos, int Middle, int RightPos)


  if (SegTree[Vertex].set != INF)


    SegTree[2*Vertex].set = SegTree[Vertex].set;

    SegTree[2*Vertex].add = 0;

    SegTree[2*Vertex].sum = (Middle - LeftPos + 1) * SegTree[Vertex].set;

    SegTree[2*Vertex].sumSq =

      (Middle - LeftPos + 1) * SegTree[Vertex].set * SegTree[Vertex].set;


    SegTree[2*Vertex+1].set = SegTree[Vertex].set;

    SegTree[2*Vertex+1].add = 0;

    SegTree[2*Vertex+1].sum = (RightPos - Middle) * SegTree[Vertex].set;

    SegTree[2*Vertex+1].sumSq = (RightPos - Middle) *

                                SegTree[Vertex].set * SegTree[Vertex].set;

    SegTree[Vertex].set = INF;



  if (SegTree[Vertex].add != 0)


    SegTree[2*Vertex].add += SegTree[Vertex].add;

    SegTree[2*Vertex].sumSq += (Middle - LeftPos + 1) * SegTree[Vertex].add

                               * SegTree[Vertex].add +

                         2LL * SegTree[Vertex].add * SegTree[2*Vertex].sum;

    SegTree[2*Vertex].sum += (Middle - LeftPos + 1) * SegTree[Vertex].add;


    SegTree[2*Vertex+1].add += SegTree[Vertex].add;

    SegTree[2*Vertex+1].sumSq += (RightPos - Middle) * SegTree[Vertex].add  *

                                 SegTree[Vertex].add +

                       2LL * SegTree[Vertex].add * SegTree[2*Vertex+1].sum;

    SegTree[2*Vertex+1].sum += (RightPos - Middle) * SegTree[Vertex].add;

    SegTree[Vertex].add = 0;




void SetValue(int Vertex, int LeftPos, int RightPos,

              int Left, int Right, int Value)


  if (Left > Right) return;

  if ((LeftPos == Left) && (RightPos == Right))


    SegTree[Vertex].add = 0;

    SegTree[Vertex].set = Value;

    SegTree[Vertex].sum = (long long)(Right - Left + 1) * Value;

    SegTree[Vertex].sumSq = (long long)(Right - Left + 1) * Value * Value;




  int Middle = (LeftPos + RightPos) / 2;



  SetValue(2*Vertex, LeftPos, Middle, Left, min(Middle,Right), Value);

  SetValue(2*Vertex+1, Middle+1, RightPos, max(Left,Middle+1), Right, Value);


  SegTree[Vertex].sum = SegTree[2*Vertex].sum + SegTree[2*Vertex+1].sum;

  SegTree[Vertex].sumSq =

    SegTree[2*Vertex].sumSq + SegTree[2*Vertex+1].sumSq;



void AddValue(int Vertex, int LeftPos, int RightPos,

              int Left, int Right, int Value)


  if (Left > Right) return;

  if ((LeftPos == Left) && (RightPos == Right))


    SegTree[Vertex].add += Value;

    SegTree[Vertex].sumSq += (long long)(Right - Left + 1) * Value * Value +

                             2LL * Value * SegTree[Vertex].sum;

    SegTree[Vertex].sum += (long long)(Right - Left + 1) * Value;




  int Middle = (LeftPos + RightPos) / 2;



  AddValue(2*Vertex, LeftPos, Middle, Left, min(Middle,Right), Value);

  AddValue(2*Vertex+1, Middle+1, RightPos, max(Left,Middle+1), Right, Value);


  SegTree[Vertex].sum = SegTree[2*Vertex].sum + SegTree[2*Vertex+1].sum;

  SegTree[Vertex].sumSq =

    SegTree[2*Vertex].sumSq + SegTree[2*Vertex+1].sumSq;



long long SumSq(int Vertex, int LeftPos, int RightPos, int Left, int Right)


  if (Left > Right) return 0;

  if ((LeftPos == Left) && (RightPos == Right)) return SegTree[Vertex].sumSq;


  int Middle = (LeftPos + RightPos) / 2;



  return SumSq(2*Vertex, LeftPos, Middle, Left, min(Middle,Right)) +

         SumSq(2*Vertex+1, Middle+1, RightPos, max(Left,Middle+1), Right);



int i, n, q, cs, tests, type, l, r, x;


int main(void)



  for(cs = 1; cs <= tests; cs++)


    scanf("%d %d",&n,&q);

    for(i = 1; i <= n; i++)




    printf("Case %d:\n",cs);





      if (type == 0)


        scanf("%d %d %d",&l,&r,&x);


      } else

      if (type == 1)    


        scanf("%d %d %d",&l,&r,&x);


      } else


        scanf("%d %d",&l,&r);





  return 0;
